{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github"
      },
      "source": [
        "\u003ca href=\"https://colab.research.google.com/github/tensorflow/tpu/blob/master/tools/colab/profiling_tpus_in_colab.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "N6ZDpd9XzFeN"
      },
      "source": [
        "##### Copyright 2018 The TensorFlow Hub Authors."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hleIN5-pcr0N"
      },
      "source": [
        "Copyright 2019-2020 Google LLC\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "    http://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License.\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ovFDeMgtjqW4"
      },
      "source": [
        "# Profiling TPUs in Colab\u0026nbsp; \u003ca href=\"https://cloud.google.com/tpu/\"\u003e\u003cimg valign=\"middle\" src=\"https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png\" width=\"50\"\u003e\u003c/a\u003e\n",
        "Adapted from [TPU colab example](https://colab.sandbox.google.com/notebooks/tpu.ipynb)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FjGRtgwPkK1k"
      },
      "source": [
        "## Overview\n",
        "This example works through training a model to classify images of\n",
        "flowers on Google's lightning-fast Cloud TPUs. Our model takes as input a photo of a flower and returns whether it is a daisy, dandelion, rose, sunflower, or tulip. A key objective of this colab is to show you how to set up and run TensorBoard, the program used for visualizing and analyzing program performance on Cloud TPU.\n",
        "\n",
        "This notebook is hosted on GitHub. To view it in its original repository, after opening the notebook, select **File \u003e View on GitHub**."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZESqCJ-PlV_R"
      },
      "source": [
        "## Instructions\n",
        "\n",
        "\u003ch3\u003e\u003ca href=\"https://cloud.google.com/tpu/\"\u003e\u003cimg valign=\"middle\" src=\"https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png\" width=\"50\"\u003e\u003c/a\u003e\u0026nbsp;\u0026nbsp;Train on TPU\u0026nbsp;\u0026nbsp; \u003c/h3\u003e\n",
        "\n",
        "* Create a Cloud Storage bucket for your TensorBoard logs at http://console.cloud.google.com/storage.  Give yourself Storage Legacy Bucket Owner permission on the bucket.\n",
        "You will need to provide the bucket name when launching TensorBoard in the **Training** section.  \n",
        "\n",
        "Note: User input is required when launching and viewing TensorBoard, so do not use **Runtime \u003e Run all** to run through the entire colab. \n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nRU4BUBT52Ns"
      },
      "source": [
        "## Authentication for connecting to GCS bucket for logging."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zxx_mPlE-saf"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "IS_COLAB_BACKEND = 'COLAB_GPU' in os.environ  # this is always set on Colab, the value is 0 or 1 depending on GPU presence\n",
        "if IS_COLAB_BACKEND:\n",
        "  from google.colab import auth\n",
        "  # Authenticates the Colab machine and also the TPU using your\n",
        "  # credentials so that they can access your private GCS buckets.\n",
        "  auth.authenticate_user()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BqUw9YcT5_wl"
      },
      "source": [
        "## Updating tensorboard_plugin_profile"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZXJSIYyUMa1v"
      },
      "outputs": [],
      "source": [
        "!pip install -U pip install -U tensorboard_plugin_profile==2.3.0"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_pQCOmISAQBu"
      },
      "source": [
        "## Enabling and testing the TPU\n",
        "\n",
        "First, you'll need to enable TPUs for the notebook:\n",
        "\n",
        "- Navigate to Edit→Notebook Settings\n",
        "- select TPU from the Hardware Accelerator drop-down\n",
        "\n",
        "Next, we'll check that we can connect to the TPU:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FpvUOuC3j27n"
      },
      "outputs": [],
      "source": [
        "%tensorflow_version 2.x\n",
        "import tensorflow as tf\n",
        "print(\"Tensorflow version \" + tf.__version__)\n",
        "\n",
        "try:\n",
        "  tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection\n",
        "  print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])\n",
        "except ValueError:\n",
        "  raise BaseException('ERROR: Not connected to a TPU runtime; please see the previous cell in this notebook for instructions!')\n",
        "\n",
        "tf.config.experimental_connect_to_cluster(tpu)\n",
        "tf.tpu.experimental.initialize_tpu_system(tpu)\n",
        "tpu_strategy = tf.distribute.experimental.TPUStrategy(tpu)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vFIMfPmgQa0h"
      },
      "outputs": [],
      "source": [
        "import re\n",
        "import numpy as np\n",
        "from matplotlib import pyplot as plt"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kvPXiovhi3ZZ"
      },
      "source": [
        "\n",
        "## Input data\n",
        "\n",
        "Our input data is stored on Google Cloud Storage. To more fully use the parallelism TPUs offer us, and to avoid bottlenecking on data transfer, we've stored our input data in TFRecord files, 230 images per file.\n",
        "\n",
        "Below, we make heavy use of `tf.data.experimental.AUTOTUNE` to optimize different parts of input loading.\n",
        "\n",
        "All of these techniques are a bit overkill for our (small) dataset, but demonstrate best practices for using TPUs.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LtAVr-4CP1rp"
      },
      "outputs": [],
      "source": [
        "AUTO = tf.data.experimental.AUTOTUNE\n",
        "\n",
        "IMAGE_SIZE = [331, 331]\n",
        "\n",
        "batch_size = 16 * tpu_strategy.num_replicas_in_sync\n",
        "\n",
        "gcs_pattern = 'gs://flowers-public/tfrecords-jpeg-331x331/*.tfrec'\n",
        "validation_split = 0.19\n",
        "filenames = tf.io.gfile.glob(gcs_pattern)\n",
        "split = len(filenames) - int(len(filenames) * validation_split)\n",
        "train_fns = filenames[:split]\n",
        "validation_fns = filenames[split:]\n",
        "        \n",
        "def parse_tfrecord(example):\n",
        "  features = {\n",
        "    \"image\": tf.io.FixedLenFeature([], tf.string), # tf.string means bytestring\n",
        "    \"class\": tf.io.FixedLenFeature([], tf.int64),  # shape [] means scalar\n",
        "    \"one_hot_class\": tf.io.VarLenFeature(tf.float32),\n",
        "  }\n",
        "  example = tf.io.parse_single_example(example, features)\n",
        "  decoded = tf.image.decode_jpeg(example['image'], channels=3)\n",
        "  normalized = tf.cast(decoded, tf.float32) / 255.0 # convert each 0-255 value to floats in [0, 1] range\n",
        "  image_tensor = tf.reshape(normalized, [*IMAGE_SIZE, 3])\n",
        "  one_hot_class = tf.reshape(tf.sparse.to_dense(example['one_hot_class']), [5])\n",
        "  return image_tensor, one_hot_class\n",
        "\n",
        "def load_dataset(filenames):\n",
        "  # Read from TFRecords. For optimal performance, we interleave reads from multiple files.\n",
        "  records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)\n",
        "  return records.map(parse_tfrecord, num_parallel_calls=AUTO)\n",
        "\n",
        "def get_training_dataset():\n",
        "  dataset = load_dataset(train_fns)\n",
        "\n",
        "  # Create some additional training images by randomly flipping and\n",
        "  # increasing/decreasing the saturation of images in the training set. \n",
        "  def data_augment(image, one_hot_class):\n",
        "    modified = tf.image.random_flip_left_right(image)\n",
        "    modified = tf.image.random_saturation(modified, 0, 2)\n",
        "    return modified, one_hot_class\n",
        "  augmented = dataset.map(data_augment, num_parallel_calls=AUTO)\n",
        "\n",
        "  # Prefetch the next batch while training (autotune prefetch buffer size).\n",
        "  return augmented.repeat().shuffle(2048).batch(batch_size).prefetch(AUTO) \n",
        "\n",
        "training_dataset = get_training_dataset()\n",
        "validation_dataset = load_dataset(validation_fns).batch(batch_size).prefetch(AUTO)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KTukaIGil7_m"
      },
      "source": [
        "Let's take a peek at the training dataset we've created:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-DwfsJq0l_DJ"
      },
      "outputs": [],
      "source": [
        "CLASSES = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']\n",
        "\n",
        "def display_one_flower(image, title, subplot, color):\n",
        "  plt.subplot(subplot)\n",
        "  plt.axis('off')\n",
        "  plt.imshow(image)\n",
        "  plt.title(title, fontsize=16, color=color)\n",
        "  \n",
        "# If model is provided, use it to generate predictions.\n",
        "def display_nine_flowers(images, titles, title_colors=None):\n",
        "  subplot = 331\n",
        "  plt.figure(figsize=(13,13))\n",
        "  for i in range(9):\n",
        "    color = 'black' if title_colors is None else title_colors[i]\n",
        "    display_one_flower(images[i], titles[i], 331+i, color)\n",
        "  plt.tight_layout()\n",
        "  plt.subplots_adjust(wspace=0.1, hspace=0.1)\n",
        "  plt.show()\n",
        "\n",
        "def get_dataset_iterator(dataset, n_examples):\n",
        "  return dataset.unbatch().batch(n_examples).as_numpy_iterator()\n",
        "\n",
        "training_viz_iterator = get_dataset_iterator(training_dataset, 9)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JTnyd7qfbYr4"
      },
      "outputs": [],
      "source": [
        "# Re-run this cell to show a new batch of images\n",
        "images, classes = next(training_viz_iterator)\n",
        "class_idxs = np.argmax(classes, axis=-1) # transform from one-hot array to class number\n",
        "labels = [CLASSES[idx] for idx in class_idxs]\n",
        "display_nine_flowers(images, labels)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ALtRUlxhw8Vt"
      },
      "source": [
        "## Model\n",
        "To get maxmimum accuracy, we leverage a pretrained image recognition model (here, [Xception](http://openaccess.thecvf.com/content_cvpr_2017/papers/Chollet_Xception_Deep_Learning_CVPR_2017_paper.pdf)). We drop the ImageNet-specific top layers (`include_top=false`), and add a max pooling and a softmax layer to predict our 5 classes."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hJl3vNtJOB-x"
      },
      "outputs": [],
      "source": [
        "def create_model():\n",
        "  pretrained_model = tf.keras.applications.Xception(input_shape=[*IMAGE_SIZE, 3], include_top=False)\n",
        "  pretrained_model.trainable = True\n",
        "  model = tf.keras.Sequential([\n",
        "    pretrained_model,\n",
        "    tf.keras.layers.GlobalAveragePooling2D(),\n",
        "    tf.keras.layers.Dense(5, activation='softmax')\n",
        "  ])\n",
        "  model.compile(\n",
        "    optimizer='adam',\n",
        "    loss = 'categorical_crossentropy',\n",
        "    metrics=['accuracy']\n",
        "  )\n",
        "  return model\n",
        "\n",
        "with tpu_strategy.scope(): # creating the model in the TPUStrategy scope means we will train the model on the TPU\n",
        "  model = create_model()\n",
        "model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dMfenMQcxAAb"
      },
      "source": [
        "## Training"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "N0wDLhZOdjAA"
      },
      "source": [
        "Calculate the number of images in each dataset. Rather than actually load the data to do so (expensive), we rely on hints in the filename. This is used to calculate the number of batches per epoch.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "M-ID7vP5mIKs"
      },
      "outputs": [],
      "source": [
        "def count_data_items(filenames):\n",
        "  # The number of data items is written in the name of the .tfrec files, i.e. flowers00-230.tfrec = 230 data items\n",
        "  n = [int(re.compile(r\"-([0-9]*)\\.\").search(filename).group(1)) for filename in filenames]\n",
        "  return np.sum(n)\n",
        "\n",
        "n_train = count_data_items(train_fns)\n",
        "n_valid = count_data_items(validation_fns)\n",
        "train_steps = count_data_items(train_fns) // batch_size\n",
        "print(\"TRAINING IMAGES: \", n_train, \", STEPS PER EPOCH: \", train_steps)\n",
        "print(\"VALIDATION IMAGES: \", n_valid)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eqYT2OcmdnBu"
      },
      "source": [
        "Calculate and show a learning rate schedule. We start with a fairly low rate, as we're using a pre-trained model and don't want to undo all the fine work put into training it."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iXJXG-Ufdbnu"
      },
      "outputs": [],
      "source": [
        "EPOCHS = 12\n",
        "\n",
        "start_lr = 0.00001\n",
        "min_lr = 0.00001\n",
        "max_lr = 0.00005 * tpu_strategy.num_replicas_in_sync\n",
        "rampup_epochs = 5\n",
        "sustain_epochs = 0\n",
        "exp_decay = .8\n",
        "\n",
        "def lrfn(epoch):\n",
        "  if epoch \u003c rampup_epochs:\n",
        "    return (max_lr - start_lr)/rampup_epochs * epoch + start_lr\n",
        "  elif epoch \u003c rampup_epochs + sustain_epochs:\n",
        "    return max_lr\n",
        "  else:\n",
        "    return (max_lr - min_lr) * exp_decay**(epoch-rampup_epochs-sustain_epochs) + min_lr\n",
        "    \n",
        "lr_callback = tf.keras.callbacks.LearningRateScheduler(lambda epoch: lrfn(epoch), verbose=True)\n",
        "\n",
        "rang = np.arange(EPOCHS)\n",
        "y = [lrfn(x) for x in rang]\n",
        "plt.plot(rang, y)\n",
        "print('Learning rate per epoch:')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Zs2tsV8xebhU"
      },
      "source": [
        "Train the model. While the first epoch will be quite a bit slower as we must XLA-compile the execution graph and load the data, later epochs should complete in ~5s."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vsxeGPJeyb5v"
      },
      "outputs": [],
      "source": [
        "# Load the TensorBoard notebook extension.\n",
        "%load_ext tensorboard"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jTrfSGu38TLd"
      },
      "outputs": [],
      "source": [
        "# Get TPU profiling service address. This address will be needed for capturing\n",
        "# profile information with TensorBoard in the following steps.\n",
        "service_addr = tpu.get_master().replace(':8470', ':8466')\n",
        "print(service_addr)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mNA__vniyY8e"
      },
      "outputs": [],
      "source": [
        "# Launch TensorBoard.\n",
        "%tensorboard --logdir=gs://bucket-name  # Replace the bucket-name variable with your own gcs bucket"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "erj2-mPpvqbG"
      },
      "source": [
        "The TensorBoard UI is displayed in a browser window. In this colab, perform the following steps to prepare to capture profile information.\n",
        "1.  Click on the dropdown menu box on the top right side and scroll down and click PROFILE. A new window appears that shows: **No profile data was found** at the top.\n",
        "1.  Click on the CAPTURE PROFILE button. A new dialog appears. The top input line shows: **Profile Service URL or TPU name**. Copy and paste the Profile Service URL (the service_addr value shown before launching TensorBoard) into the top input line. While still on the dialog box, start the training with the next step.\n",
        "1.  Click on the next colab cell to start training the model.\n",
        "1.  Watch the output from the training until several epochs have completed. This allows time for the profile data to start being collected. Return to the dialog box and click on the CAPTURE button. If the capture succeeds, the page will auto refresh and redirect you to the profiling results."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7Qv8rC4aVOFB"
      },
      "outputs": [],
      "source": [
        "history = model.fit(training_dataset, validation_data=validation_dataset,\n",
        "                    steps_per_epoch=train_steps, epochs=EPOCHS, callbacks=[lr_callback])\n",
        "\n",
        "final_accuracy = history.history[\"val_accuracy\"][-5:]\n",
        "print(\"FINAL ACCURACY MEAN-5: \", np.mean(final_accuracy))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VngeUBIdyJ1T"
      },
      "outputs": [],
      "source": [
        "def display_training_curves(training, validation, title, subplot):\n",
        "  ax = plt.subplot(subplot)\n",
        "  ax.plot(training)\n",
        "  ax.plot(validation)\n",
        "  ax.set_title('model '+ title)\n",
        "  ax.set_ylabel(title)\n",
        "  ax.set_xlabel('epoch')\n",
        "  ax.legend(['training', 'validation'])\n",
        "\n",
        "plt.subplots(figsize=(10,10))\n",
        "plt.tight_layout()\n",
        "display_training_curves(history.history['accuracy'], history.history['val_accuracy'], 'accuracy', 211)\n",
        "display_training_curves(history.history['loss'], history.history['val_loss'], 'loss', 212)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o19yp4TxqmgI"
      },
      "source": [
        "Accuracy goes up and loss goes down. Looks good!"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a_rjVo-RAoYd"
      },
      "source": [
        "## Next steps\n",
        "\n",
        "More TPU/Keras examples include:\n",
        "- [Shakespeare in 5 minutes with Cloud TPUs and Keras](https://colab.research.google.com/github/tensorflow/tpu/blob/master/tools/colab/shakespeare_with_tpu_and_keras.ipynb)\n",
        "- [Fashion MNIST with Keras and TPUs](https://colab.research.google.com/github/tensorflow/tpu/blob/master/tools/colab/fashion_mnist.ipynb)\n",
        "\n",
        "We'll be sharing more examples of TPU use in Colab over time, so be sure to check back for additional example links, or [follow us on Twitter @GoogleColab](https://twitter.com/googlecolab)."
      ]
    }
  ],
  "metadata": {
    "accelerator": "TPU",
    "colab": {
      "collapsed_sections": [],
      "history_visible": true,
      "name": "Profiling TPUs in Colab",
      "provenance": [
        {
          "file_id": "/piper/depot/google3/third_party/cloud_tpu/colab/Profiling_TPUs_in_Colab.ipynb?workspaceId=gmadrone:profiling-tpus::citc",
          "timestamp": 1606874904160
        }
      ],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
